Evaluation of CycleGAN models¶
The quality of our results varied greatly between training runs and different selected training images. We chose our best model for the evaluation. Since we no longer have a ground truth for comparison, as in our previous experiments on black-and-white translation, we conduct a qualitative assessment of the generated images through visual inspection. For the evaluation, we used 64 images taken with a Sony digital camera. We process these images with both generators: G translates them into the look of Kodak Portra 400, and F translates them into the look of Cinestill 800T.
In [25]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from tqdm import tqdm
from PIL import Image
from IPython.display import display
import os
import matplotlib.pyplot as plt
from torchvision import transforms
import torchvision.utils as vutils
import torch.nn.functional as F
In [3]:
# Optimizer settings (must match training so the checkpointed optimizer
# state can be restored without a shape/key mismatch).
learning_rate = 1e-4
# Beta1 hyperparameter for Adam optimizers
beta1 = 0.5
batch_size = 128
# Checkpoint file paths
# `experiment` selects the subdirectory under models/ (see load_checkpoint below).
experiment = "cinestill_800t_portra_400_wo_Identity_loss"
check_G_A2B = "G_A2B.pth.tar"  # generator A -> B (used below to produce Kodak Portra images)
check_G_B2A = "G_B2A.pth.tar"  # generator B -> A (used below to produce Cinestill images)
check_D_A = "D_A.pth.tar"  # discriminator for domain A (not loaded in this notebook)
check_D_B = "D_B.pth.tar"  # discriminator for domain B (not loaded in this notebook)
In [4]:
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
In [29]:
class ResNetLUTGenerator(nn.Module):
    """Predict a per-image 3D color LUT with a small CNN and apply it.

    The CNN summarizes the input image into a 128-dim feature vector, from
    which a fully connected layer regresses a lut_size^3 x 3 lookup table.
    The predicted LUT is then applied to the input image to produce the
    color-graded output.

    Args:
        lut_size: number of bins per color axis of the predicted 3D LUT.
    """

    def __init__(self, lut_size=8):
        super(ResNetLUTGenerator, self).__init__()
        self.lut_size = lut_size
        # Feature extractor: three conv stages followed by global average
        # pooling, so the LUT prediction is independent of input resolution.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )
        self.fc = nn.Linear(128, lut_size * lut_size * lut_size * 3)  # Output LUT values

    def _trilinear_interpolation(self, luts, images):
        """Apply the LUTs with trilinear interpolation via grid_sample.

        Args:
            luts: [B, S, S, S, 3] lookup tables.
            images: [B, 3, H, W] images, assumed to be in [0, 1].

        Returns:
            [B, 3, H, W] transformed images.
        """
        # grid_sample expects sampling coordinates in [-1, 1].
        img = (images - 0.5) * 2.0  # Shape: [B, 3, H, W]
        # Reinterpret each pixel's RGB values as a 3D sampling coordinate.
        img = img.permute(0, 2, 3, 1).unsqueeze(1)  # Shape: [B, 1, H, W, 3]
        # Move the color channel in front of the three LUT axes.
        LUT = luts.permute(0, 4, 1, 2, 3)  # Shape: [B, 3, S, S, S]
        # 'bilinear' mode on a 5D input performs trilinear interpolation.
        result = F.grid_sample(LUT, img, mode='bilinear', padding_mode='border', align_corners=True)  # [B, 3, 1, H, W]
        result = result.squeeze(2)  # Remove the extra "depth" dimension: [B, 3, H, W]
        return result

    def _simple_approach(self, luts, images):
        """Apply the LUTs by nearest-bin lookup (no interpolation).

        Args:
            luts: [B, S, S, S, 3] lookup tables.
            images: [B, 3, H, W] images, assumed to be in [0, 1].

        Returns:
            [B, 3, H, W] transformed images.
        """
        size = luts.shape[2]
        lut = luts.view(luts.shape[0], size, size, size, 3)
        # Map [0, 1] pixel values to integer bin indices in [0, size - 1].
        # BUGFIX: the parentheses matter — the previous expression
        # (images * size - 1) shifted every index down by roughly one bin
        # and mapped 0.0 to -1 before clamping.
        indices = (images * (size - 1)).long()
        indices = torch.clamp(indices, 0, size - 1)
        r = indices[:, 0, :, :]
        g = indices[:, 1, :, :]
        b = indices[:, 2, :, :]
        # Per-sample gather: the batch index broadcasts over the spatial dims.
        transformed = lut[torch.arange(luts.shape[0]).unsqueeze(-1).unsqueeze(-1), r, g, b]
        transformed = transformed.permute(0, 3, 1, 2)
        return transformed

    def forward(self, x):
        """Predict a LUT for each image in x and apply it.

        Args:
            x: [B, 3, H, W] images, assumed to be in [0, 1].

        Returns:
            Tuple of (transformed images [B, 3, H, W],
                      predicted LUTs [B, S, S, S, 3]).
        """
        features = self.cnn(x)
        features = features.view(features.size(0), -1)  # Flatten to [B, 128]
        lut = self.fc(features)
        lut = lut.view(-1, self.lut_size, self.lut_size, self.lut_size, 3)
        lut = torch.sigmoid(lut)  # Keep LUT entries in [0, 1]
        transformed = self._simple_approach(lut, x)
        return transformed, lut
In [30]:
# Instantiate both direction generators with a 33-bin LUT and move them to the device.
netG_A2B = ResNetLUTGenerator(lut_size=33).to(device)  # produces the Kodak Portra look below
netG_B2A = ResNetLUTGenerator(lut_size=33).to(device)  # produces the Cinestill look below
In [31]:
# Single Adam optimizer over both generators' parameters. It is only needed
# here so the optimizer state stored in the checkpoints can be restored.
opt_gen = optim.Adam(
    list(netG_A2B.parameters()) + list(netG_B2A.parameters()),
    lr=learning_rate,
    # Use the beta1 constant defined above instead of repeating the literal 0.5.
    betas=(beta1, 0.999),
)
Load generators¶
In [32]:
def load_checkpoint(model, optimizer, experiment, filepath):
    """Restore model and optimizer state from models/<experiment>/<filepath>.

    Both objects are updated in place and returned for convenience.
    """
    print("=> Loading checkpoint")
    checkpoint_path = f"models/{experiment}/{filepath}"
    # weights_only=True restricts unpickling to tensors/primitives for safety.
    state = torch.load(checkpoint_path, weights_only=True)
    model.load_state_dict(state["state_dict"])
    optimizer.load_state_dict(state["optimizer"])
    return model, optimizer
In [33]:
# Load the trained weights into both generators; the optimizer return value
# is ignored (opt_gen is updated in place by load_checkpoint).
genA2B, _ = load_checkpoint(netG_A2B, opt_gen, experiment, check_G_A2B)
genB2A, _ = load_checkpoint(netG_B2A, opt_gen, experiment, check_G_B2A)
=> Loading checkpoint => Loading checkpoint
Get Sony test images¶
In [34]:
class ValDataset(Dataset):
    """Dataset over a pre-loaded collection of validation images.

    Applies `transform` to each item on access when one is provided.
    """

    def __init__(self, val_images, transform=None):
        self.val_images = val_images
        self.transform = transform

    def __len__(self):
        return len(self.val_images)

    def __getitem__(self, idx):
        item = self.val_images[idx]
        return self.transform(item) if self.transform else item
In [35]:
# Resize each PIL image to 64x64 and convert it to a [0, 1] CHW tensor.
transform64 = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])
In [17]:
# Applied by ValDataset to the stored HWC numpy arrays; ToTensor converts them back to CHW tensors.
transform = transforms.Compose([transforms.ToTensor()])
In [18]:
def load_images_from_directory(directory_path, num_images=100):
    """Load up to num_images RGB images from a directory as one tensor.

    Each image is resized to 64x64 via the module-level `transform64` and the
    batch is returned in channel-last layout.

    Args:
        directory_path: directory to scan for .png/.jpg/.jpeg files.
        num_images: maximum number of images to load.

    Returns:
        Float tensor of shape [N, 64, 64, 3] with values in [0, 1]; an empty
        [0, 64, 64, 3] tensor when the directory contains no image files.
    """
    # Ensure only images are processed
    image_files = [
        f for f in os.listdir(directory_path)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    ]
    images = []
    for image_file in image_files[:num_images]:
        image_path = os.path.join(directory_path, image_file)
        img = Image.open(image_path).convert('RGB')  # Convert to RGB in case of grayscale
        images.append(transform64(img))  # Apply the transformations
    if not images:
        # torch.stack raises on an empty list; return an empty batch instead.
        return torch.empty(0, 64, 64, 3)
    images_tensor = torch.stack(images)
    # [N, C, H, W] -> [N, H, W, C] so the batch round-trips through numpy/ToTensor later.
    images_tensor = images_tensor.permute(0, 2, 3, 1)
    return images_tensor
In [20]:
# Load up to 128 Sony test photos (resized to 64x64) as an [N, 64, 64, 3] batch.
val_images = load_images_from_directory('../../sony_images', num_images=128)
In [21]:
# Wrap the images as numpy HWC arrays; `transform` (ToTensor) restores CHW layout per item.
val_dataset = ValDataset(val_images.numpy(), transform)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
In [22]:
# Grab one shuffled batch for evaluation (batch_size=128 covers all loaded images).
ev_images = next(iter(val_dataloader))
ev_images = ev_images.to(device)
Visualization of the original test images¶
In [23]:
# Show a grid of (up to) 64 of the original Sony test images.
plt.figure(figsize=(16,16))
plt.axis("off")
plt.title("Test Images")
# make_grid tiles the batch; normalize=True rescales values into [0, 1] for display.
plt.imshow(np.transpose(vutils.make_grid(ev_images.to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
plt.show()
In [41]:
# Put both generators into evaluation mode on the selected device.
genA2B.eval()
genB2A.eval()
genA2B.to(device)
genB2A.to(device)
# Inference only: disable autograd so no computation graph is built
# (saves memory and the outputs never need gradients).
with torch.no_grad():
    # Translate the Sony test images into the Kodak Portra 400 look
    fake_kodak_portra, _ = genA2B(ev_images)
    # Translate the Sony test images into the Cinestill 800T look
    fake_cinestill, _ = genB2A(ev_images)
Visualization of the edited images translated to Kodak Portra 400¶
In [37]:
# Show a grid of (up to) 64 images translated into the Kodak Portra 400 look.
plt.figure(figsize=(16,15))
plt.axis("off")
plt.title("Fake Portra Images")
plt.imshow(np.transpose(vutils.make_grid(fake_kodak_portra.to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
plt.show()
Visualization of the edited images translated to Cinestill 800T¶
In [42]:
# Show a grid of (up to) 64 images translated into the Cinestill 800T look.
plt.figure(figsize=(16,16))
plt.axis("off")
plt.title("Fake Cinestill Images")
plt.imshow(np.transpose(vutils.make_grid(fake_cinestill.to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
plt.show()
Visualization of a single image¶
In [40]:
def display_images(images, titles=None):
    """Display a row of image tensors side by side.

    Args:
        images: sequence of [3, H, W] tensors, assumed to hold values in [0, 1].
        titles: optional per-image titles; extra images stay untitled.
    """
    # One subplot per image — generalized from the original hard-coded three
    # columns; the width scales so three images still give the old (9, 5) size.
    fig, axes = plt.subplots(1, len(images), figsize=(3 * len(images), 5))
    # subplots returns a bare Axes (not an array) when there is a single column.
    axes = np.atleast_1d(axes)
    # Convert and display each image
    for idx, (ax, img) in enumerate(zip(axes, images)):
        # Tensor [C, H, W] -> numpy [H, W, C] uint8 for imshow
        tensor_image = img.detach().cpu()
        image_np = np.transpose(tensor_image.numpy(), (1, 2, 0))
        image_np = (image_np * 255).clip(0, 255).astype(np.uint8)
        ax.imshow(image_np)
        ax.axis('off')
        # Set title if provided
        if titles and idx < len(titles):
            ax.set_title(titles[idx])
    # Adjust layout to prevent overlap
    plt.tight_layout()
    plt.show()
In [60]:
# Side-by-side comparison for one sample: original vs. both film-look translations.
img_id = 9
display_images([ev_images[img_id], fake_kodak_portra[img_id], fake_cinestill[img_id]], ["Original Sony","Fake Kodak Portra", "Fake Cinestill"])